{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch as tc\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hosvd(x):\n",
    "    \"\"\"\n",
    "    :param x: 待分解的张量\n",
    "    :return G: 核张量\n",
    "    :return U: 变换矩阵\n",
    "    :return lm: 各个键约化矩阵的本征谱\n",
    "    \"\"\"\n",
    "    ndim = x.ndim\n",
    "    U = list()  # 用于存取各个变换矩阵\n",
    "    lm = list()  # 用于存取各个键约化矩阵的本征谱\n",
    "    x = tc.from_numpy(x)\n",
    "    for n in range(ndim):\n",
    "        index = list(range(ndim))\n",
    "        index.pop(n)\n",
    "        _mat = tc.tensordot(x, x, [index, index])\n",
    "        _lm, _U = tc.symeig(_mat, eigenvectors=True)\n",
    "        lm.append(_lm.numpy())\n",
    "        U.append(_U)\n",
    "    # 计算核张量\n",
    "    G = tucker_product(x, U)\n",
    "    U1 = [u.numpy() for u in U]\n",
    "    return G, U1, lm\n",
    "\n",
    "\n",
    "def tucker_product(x, U, dim=1):\n",
    "    \"\"\"\n",
    "    :param x: 张量\n",
    "    :param U: 变换矩阵\n",
    "    :param dim: 收缩各个矩阵的第几个指标\n",
    "    :return G: 返回Tucker乘积的结果\n",
    "    \"\"\"\n",
    "    ndim = x.ndim\n",
    "    if type(x) is not tc.Tensor:\n",
    "        x = tc.from_numpy(x)\n",
    "        \n",
    "    U1 = list()\n",
    "    for n in range(len(U)):\n",
    "        if type(U[n]) is not tc.Tensor:\n",
    "            U1.append(tc.from_numpy(U[n]))\n",
    "        else:\n",
    "            U1.append(U[n])\n",
    "    \n",
    "    ind_x = ''\n",
    "    for n in range(ndim):\n",
    "        ind_x += chr(97 + n)\n",
    "    ind_x1 = ''\n",
    "    for n in range(ndim):\n",
    "        ind_x1 += chr(97 + ndim + n)\n",
    "    contract_eq = copy.deepcopy(ind_x)\n",
    "    for n in range(ndim):\n",
    "        if dim == 0:\n",
    "            contract_eq += ',' + ind_x[n] + ind_x1[n]\n",
    "        else:\n",
    "            contract_eq += ',' + ind_x1[n] + ind_x[n]\n",
    "    contract_eq += '->' + ind_x1\n",
    "    # print(x.shape, U[0].shape, U[1].shape, U[2].shape)\n",
    "    # print(type(contract_eq), contract_eq)\n",
    "    G = tc.einsum(contract_eq, [x] + U1)\n",
    "    G = G.numpy()\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "检查Tucker分解等式是否成立：\n",
      "Tucker分解误差 = 2.1577170916414653e-15\n",
      "\n",
      "检查各个变换矩阵为正交阵：\n",
      "\n",
      "对于第0个变换矩阵，VV.T = \n",
      "[[ 1.00000000e+00  6.51417959e-18  3.45948176e-17]\n",
      " [ 6.51417959e-18  1.00000000e+00 -4.20175670e-17]\n",
      " [ 3.45948176e-17 -4.20175670e-17  1.00000000e+00]]\n",
      "\n",
      "对于第1个变换矩阵，VV.T = \n",
      "[[ 1.00000000e+00 -1.11022302e-16 -6.93889390e-17 -1.11022302e-16]\n",
      " [-1.11022302e-16  1.00000000e+00 -1.11022302e-16 -8.67361738e-17]\n",
      " [-6.93889390e-17 -1.11022302e-16  1.00000000e+00  8.32667268e-17]\n",
      " [-1.11022302e-16 -8.67361738e-17  8.32667268e-17  1.00000000e+00]]\n",
      "\n",
      "对于第2个变换矩阵，VV.T = \n",
      "[[1.00000000e+00 1.57674166e-17]\n",
      " [1.57674166e-17 1.00000000e+00]]\n"
     ]
    }
   ],
   "source": [
    "tensor = np.random.randn(3, 4, 2)\n",
    "Core, V, LM = hosvd(tensor)\n",
    "\n",
    "print('检查Tucker分解等式是否成立：')\n",
    "tensor1 = tucker_product(Core, V, dim=0)\n",
    "error = np.linalg.norm(tensor - tensor1)\n",
    "print('Tucker分解误差 = ' + str(error))\n",
    "\n",
    "print('\\n检查各个变换矩阵为正交阵：')\n",
    "for n, v in enumerate(V):\n",
    "    print('\\n对于第' + str(n) +'个变换矩阵，VV.T = ')\n",
    "    print(v.dot(v.T))\n",
    "\n",
    "# ----------------------------------------------------\n",
    "# 练习：\n",
    "# 1. 编写任意给定四阶张量HOSVD的程序，并于jupyter notebook中任意阶张量的分解结果对比；\n",
    "# 2. 编写程序验证核张量满足的性质。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
